Ukázka BERTopic

V R v rámci RStudia

Autor

Jan Netík

Načteme R balíky.

library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.4
✔ forcats   1.0.0     ✔ stringr   1.5.0
✔ ggplot2   3.4.4     ✔ tibble    3.2.1
✔ lubridate 1.9.3     ✔ tidyr     1.3.0
✔ purrr     1.0.2     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(here)
here() starts at /Users/netik/Documents/git/tmws
library(reticulate)

# pomocné funkce
source(here("shared.R"), local = TRUE)

Použijeme prostředí Pythonu, které jsme si nakonfigurovali dříve1.

use_condaenv("topic_modeling")

Načteme Python moduly.

bt <- import("bertopic")
bt_repre <- import("bertopic.representation")
st <- import("sentence_transformers")
sklearn_datasets <- import("sklearn.datasets")
cv <- import("sklearn.feature_extraction.text")

Data

Klasický dataset 20 newsgroups je distribuován v modulu sklearn.datesets. Jde o cca 20K krátkých zpráv, které si vyměňovali uživatelé Usenetu (tj. někdy v 90. letech). Zprávy pocházejí rovnoměrně z 20 kanálů.

Funkce rovnou umožňuje smazat různé technické části zpráv, které nás nezajímají. Nakonec z objektu vybereme pouze data.

newsgroups_raw <- sklearn_datasets$fetch_20newsgroups(
  subset = "all",
  remove = tuple("headers", "footers", "quotes")
)

# data samotná jsou ještě o úroveň níže
d_newsgroups <- newsgroups_raw[["data"]]

Trénujeme model

Získáme zvlášť embeddings.

# na Apple Silicon strojích chceme využít "grafickou kartu"
st_device <- NULL

if (version$platform == "aarch64-apple-darwin20") {
  st_device <- "mps"
}

# definujeme model
embedding_model <- st$SentenceTransformer("all-MiniLM-L6-v2", device = st_device)

# spočítáme embeddings, pokud už je nemáme
if (file.exists(here("data/embeddings/newsgroups.rds"))) {
  embeddings <- read_rds(here("data/embeddings/newsgroups.rds"))
} else {
  embeddings <- embedding_model$encode(d_newsgroups, show_progress_bar = TRUE)
}

Odhadneme zbytek modelu.

topic_model <- bt$BERTopic(language = "english", verbose = TRUE)

output <- topic_model$fit_transform(d_newsgroups, embeddings = embeddings)

tbl_output <- tibble(
  document = seq_along(d_newsgroups),
  topic = output[[1]],
  prob = output[[2]]
)

Výsledky

Základní přehled témat, počtů příslušných dokumentů a shrnutí clusterů získáte skrze:

# náhled omezíme jen na 5 prvních témat a z plných ukázek reprezentativních dokumentů
# uděláme úryvky o 100 znacích (vše kvůli renderování reportu, pro práci v RStudiu
# není třeba)

topic_model$get_topic_info() |>
  head() |>
  mutate(Representative_Docs = map(Representative_Docs, \(x) str_trunc(x, 100)))

Nejvíce zastoupené téma je pod indexem 0. Podrobnosti o tématu získáme příkazem níže. Výstupem jsou slova s nejvyšší c-TF-IDF hodnotou, zobrazenou u každého slova. c-TF-IDF je TF-IDF fungující ne na úrovni dokumentů, ale clusterů. Říká, jaká slova jsou nejvíce relevantní pro dané téma; vychází ze “slepence” dokumentů v rámci clusteru.

topic_model$get_topic(4) |> tuple_list_to_tibble()
topic_model$visualize_barchart()

UMAP do 2 dimenzí

topic_model$visualize_documents(d_newsgroups, embeddings = embeddings, topics = tuple(as.list(1:10)))
topic_model$visualize_heatmap(top_n_topics = 50L)
topic_model$visualize_hierarchy(orientation = "left", top_n_topics = 10L)
A marker object has been specified, but markers is not in the mode
Adding markers to the mode...
A marker object has been specified, but markers is not in the mode
Adding markers to the mode...
A marker object has been specified, but markers is not in the mode
Adding markers to the mode...
A marker object has been specified, but markers is not in the mode
Adding markers to the mode...
A marker object has been specified, but markers is not in the mode
Adding markers to the mode...
A marker object has been specified, but markers is not in the mode
Adding markers to the mode...
A marker object has been specified, but markers is not in the mode
Adding markers to the mode...
A marker object has been specified, but markers is not in the mode
Adding markers to the mode...
A marker object has been specified, but markers is not in the mode
Adding markers to the mode...

Můžeme srovnat s nějakými externími informacemi. Tady např. víme, v jakém novinkovém kanálu se zpráva objevila. Můžeme použít i údaj o výzkumné skupině atp.

classes_nms <- newsgroups_raw$target_names
classes_idx <- newsgroups_raw$target + 1L # prevent zeros


classes_fct <- factor(classes_idx, labels = classes_nms)

topics_per_class <- topic_model$topics_per_class(d_newsgroups, classes_fct)

topic_model$visualize_topics_per_class(topics_per_class = topics_per_class)

Embeddings, UMAP a HDBSCAN necháme, jak jsme je “nafitovali”, ale zkusíme poladit reprezentace clusterů.

Nejprve zkusíme vyřadit anglická stop slova, hledat a spojovat slovní spojení až o třech slovech a brát v potaz jen slova, co se vyskytnou alespoň 20krát:

vectorizer_model <-  cv$CountVectorizer(stop_words="english", ngram_range = tuple(1L, 3L), min_df = 20L)

topic_model$update_topics(d_newsgroups, vectorizer_model=vectorizer_model)
topic_model$visualize_barchart()
# representation_model_keybert <- bt_repre$KeyBERTInspired()
# topic_model$update_topics(d_newsgroups, vectorizer_model=vectorizer_model, representation_model = representation_model_keybert)

Hrát si můžeme i se samotnými popisky. Tento krok se defaultně nepoužívá. Zkusíme model od Facebooku, který umí obsah témat klasifikovat do kategorií, na kterých ani nemusel být trénovaný (tzv. “zero-shot” klasifikace). Je ale třeba vymyslet sadu kandidátních popisků (a doufat, že to náš počítač upočítá).

candidate_topics <- c("computers", "bicycles", "cars", "sport", "palestine", "health")
representation_model <- bt_repre$ZeroShotClassification(candidate_topics = candidate_topics, model="facebook/bart-large-mnli") 

# updatujeme jenom repre model
topic_model$update_topics(d_newsgroups, representation_model = representation_model)
topic_model$get_topic_info() 
topic_model$visualize_barchart()

Pozn: máme-li peníze a API na OpenAI, můžeme o shrnutí témat poprosit ChatGPT.